{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Third-party packages\n",
    "import torch\n",
    "from transformers import AutoProcessor\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.ndimage import zoom\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "# Fetch a snapshot of the model repo and put its directory first on sys.path,\n",
    "# so that `modeling_custom_llava` (expected to live in that snapshot) can be\n",
    "# imported like a local module.\n",
    "repo_path = snapshot_download(\"amildravid4292/llava-llama-3-8b-test-time-registers\")\n",
    "sys.path.insert(0, repo_path)\n",
    "from modeling_custom_llava import LlavaRegistersForConditionalGeneration\n",
    "\n",
    "# Device that the model and its inputs are moved to in later cells\n",
    "device = \"cuda:0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AttentionCaptureModel(LlavaRegistersForConditionalGeneration):\n",
    "    \"\"\"Language model attention capture.\n",
    "\n",
    "    Thin subclass that keeps the ``attentions`` field of the most recent\n",
    "    forward output on the instance, so it can be read after the call.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        \"\"\"Build the parent model and start with nothing captured.\"\"\"\n",
    "        super().__init__(*args, **kwargs)\n",
    "        # Replaced on every forward call; stays None until the first one.\n",
    "        self.captured_attentions = None\n",
    "\n",
    "    def forward(self, *args, **kwargs):\n",
    "        \"\"\"Run the parent forward pass, store its attentions, and return the output unchanged.\"\"\"\n",
    "        result = super().forward(*args, **kwargs)\n",
    "        self.captured_attentions = result.attentions\n",
    "        return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the base checkpoint through the attention-capturing subclass, in half\n",
    "# precision and with attention outputs requested, then move it to the device.\n",
    "model = AttentionCaptureModel.from_pretrained(\n",
    "    \"xtuner/llava-llama-3-8b-v1_1-transformers\",\n",
    "    torch_dtype=torch.float16,\n",
    "    output_attentions=True,\n",
    ")\n",
    "model = model.to(device)\n",
    "\n",
    "# use original processor\n",
    "processor = AutoProcessor.from_pretrained(\"xtuner/llava-llama-3-8b-v1_1-transformers\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hook output of vision model\n",
    "patches = {}\n",
    "\n",
    "\n",
    "def hook_output_patch(module, input, output):\n",
    "    \"\"\"Forward hook: store the hooked module's output under key -1 of `patches`.\"\"\"\n",
    "    patches[-1] = output\n",
    "\n",
    "\n",
    "# Attach the hook to the last encoder layer of the vision tower and keep the\n",
    "# handle so the hook can be removed again later.\n",
    "last_vision_layer = model.vision_tower.vision_model.encoder.layers[-1]\n",
    "hook_handle = last_vision_layer.register_forward_hook(hook_output_patch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = (\"<|start_header_id|>user<|end_header_id|>\\n\\n<image>\\nHow many tennis balls are in the dog's mouth? Use one word.<|eot_id|>\"\n",
    "          \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\")\n",
    "\n",
    "# Load image\n",
    "image_path = \"images/dog_img.webp\"\n",
    "raw_image = Image.open(image_path)\n",
    "\n",
    "# Name the processor arguments explicitly: the positional order of text and\n",
    "# images in the processor's __call__ is not the same in every transformers\n",
    "# release, so positional passing can be misinterpreted.\n",
    "inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(device, torch.float16)\n",
    "\n",
    "# use original model without test-time register\n",
    "with torch.no_grad():\n",
    "    output = model.generate(**inputs, max_new_tokens=1, do_sample=False, extra_tokens=0, neuron_dict=None)\n",
    "\n",
    "# Decode the first returned sequence to text, dropping special tokens\n",
    "tokenizer = processor.tokenizer\n",
    "decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "print(\"Decoded output:\", decoded_output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norm along the last dim of the first element of the hooked output, in float32\n",
    "hooked_states = patches[-1][0].float().squeeze(0)\n",
    "patch_norms = torch.norm(hooked_states, dim=-1).detach().cpu().numpy()\n",
    "# The capture has been read, so the hook is no longer needed\n",
    "hook_handle.remove()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.axis('off')\n",
    "fig.suptitle(\"Output Patch Norm\", fontsize=25)\n",
    "\n",
    "# Skip the first token (presumably CLS - confirm) and lay the rest out as a 24x24 grid\n",
    "norm_grid = patch_norms[1:].reshape(24, 24)\n",
    "im = ax.imshow(norm_grid)\n",
    "fig.colorbar(im, ax=ax)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize attention from answer to visual tokens\n",
    "# Slice out the last query row and key columns 5:581 of every captured attention\n",
    "# tensor before upcasting, so the full stack is never materialised in float32.\n",
    "atts = torch.cat([layer_att[..., -1, 5:581] for layer_att in model.captured_attentions]).float()\n",
    "# Average over the two leading dims (presumably layers, then heads - confirm)\n",
    "# and lay the 576 values out as a 24x24 grid.\n",
    "att_grid = atts.mean(0).mean(0).cpu().reshape(24, 24)\n",
    "\n",
    "im = plt.imshow(att_grid)\n",
    "plt.axis(\"off\")\n",
    "plt.suptitle(\"Mean Attention Map for Answer Token \", fontsize = 20)\n",
    "# Add the colorbar before tight_layout so the layout pass accounts for it\n",
    "plt.colorbar(im)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean attention from the last query position to key columns 5:581, as a 24x24 grid.\n",
    "# Slicing happens before the float32 upcast so only the needed row is converted.\n",
    "atts = torch.cat([layer_att[..., -1, 5:581] for layer_att in model.captured_attentions]).float()\n",
    "atts = atts.mean(0).mean(0).cpu().reshape(24, 24).numpy()\n",
    "\n",
    "# The processor returns mean/std-normalized pixel values; undo the normalization\n",
    "# so imshow receives RGB values in [0, 1] instead of clipping them.\n",
    "pixel_values = inputs[\"pixel_values\"][0].float().cpu()\n",
    "image_mean = torch.tensor(processor.image_processor.image_mean).view(-1, 1, 1)\n",
    "image_std = torch.tensor(processor.image_processor.image_std).view(-1, 1, 1)\n",
    "image = (pixel_values * image_std + image_mean).clamp(0, 1).permute(1, 2, 0).numpy()\n",
    "\n",
    "# Upsample the grid to the image resolution (336/24 for the inputs used here)\n",
    "scale_factor = image.shape[0] / atts.shape[0]\n",
    "heatmap_upsampled = zoom(atts, scale_factor, order=1)  # bilinear interpolation\n",
    "\n",
    "# Create the overlay\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.imshow(image)  # Show original image\n",
    "ax.imshow(heatmap_upsampled, alpha=0.5, cmap='jet')  # Overlay heatmap with transparency\n",
    "ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reset hook\n",
    "patches = {}\n",
    "hook_handle = model.vision_tower.vision_model.encoder.layers[-1].register_forward_hook(hook_output_patch)\n",
    "\n",
    "prompt = (\"<|start_header_id|>user<|end_header_id|>\\n\\n<image>\\nHow many tennis balls are in the dog's mouth? Use one word.<|eot_id|>\"\n",
    "          \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\")\n",
    "\n",
    "# Load image\n",
    "image_path = \"images/dog_img.webp\"\n",
    "raw_image = Image.open(image_path)\n",
    "\n",
    "# Name the processor arguments explicitly: the positional order of text and\n",
    "# images in the processor's __call__ is not the same in every transformers\n",
    "# release, so positional passing can be misinterpreted.\n",
    "inputs = processor(text=prompt, images=raw_image, return_tensors='pt').to(device, torch.float16)\n",
    "\n",
    "# default uses test-time register\n",
    "with torch.no_grad():\n",
    "    output = model.generate(**inputs, max_new_tokens=1, do_sample=False)\n",
    "\n",
    "# Decode the first returned sequence to text, dropping special tokens\n",
    "tokenizer = processor.tokenizer\n",
    "decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "print(\"Decoded output:\", decoded_output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norm along the last dim of the first element of the hooked output, in float32\n",
    "hooked_states = patches[-1][0].float().squeeze(0)\n",
    "patch_norms = torch.norm(hooked_states, dim=-1).detach().cpu().numpy()\n",
    "# The capture has been read, so the hook is no longer needed\n",
    "hook_handle.remove()\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.axis('off')\n",
    "fig.suptitle(\"Output Patch Norm\", fontsize=25)\n",
    "\n",
    "# Skip the first and last tokens (presumably CLS and the test-time register - confirm)\n",
    "# and lay the rest out as a 24x24 grid\n",
    "norm_grid = patch_norms[1:-1].reshape(24, 24)\n",
    "im = ax.imshow(norm_grid)\n",
    "fig.colorbar(im, ax=ax)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize attention from answer to visual tokens\n",
    "# Slice out the last query row and key columns 5:581 of every captured attention\n",
    "# tensor before upcasting, so the full stack is never materialised in float32.\n",
    "atts = torch.cat([layer_att[..., -1, 5:581] for layer_att in model.captured_attentions]).float()\n",
    "# Average over the two leading dims (presumably layers, then heads - confirm)\n",
    "# and lay the 576 values out as a 24x24 grid.\n",
    "att_grid = atts.mean(0).mean(0).cpu().reshape(24, 24)\n",
    "\n",
    "im = plt.imshow(att_grid)\n",
    "plt.axis(\"off\")\n",
    "plt.suptitle(\"Mean Attention Map for Answer Token \", fontsize = 20)\n",
    "# Add the colorbar before tight_layout so the layout pass accounts for it\n",
    "plt.colorbar(im)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean attention from the last query position to key columns 5:581, as a 24x24 grid.\n",
    "# Slicing happens before the float32 upcast so only the needed row is converted.\n",
    "atts = torch.cat([layer_att[..., -1, 5:581] for layer_att in model.captured_attentions]).float()\n",
    "atts = atts.mean(0).mean(0).cpu().reshape(24, 24).numpy()\n",
    "\n",
    "# The processor returns mean/std-normalized pixel values; undo the normalization\n",
    "# so imshow receives RGB values in [0, 1] instead of clipping them.\n",
    "pixel_values = inputs[\"pixel_values\"][0].float().cpu()\n",
    "image_mean = torch.tensor(processor.image_processor.image_mean).view(-1, 1, 1)\n",
    "image_std = torch.tensor(processor.image_processor.image_std).view(-1, 1, 1)\n",
    "image = (pixel_values * image_std + image_mean).clamp(0, 1).permute(1, 2, 0).numpy()\n",
    "\n",
    "# Upsample the grid to the image resolution (336/24 for the inputs used here)\n",
    "scale_factor = image.shape[0] / atts.shape[0]\n",
    "heatmap_upsampled = zoom(atts, scale_factor, order=1)  # bilinear interpolation\n",
    "\n",
    "# Create the overlay\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.imshow(image)  # Show original image\n",
    "ax.imshow(heatmap_upsampled, alpha=0.5, cmap='jet')  # Overlay heatmap with transparency\n",
    "ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llava2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
